{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 1.6411768e-02,  2.0951958e-02, -8.3168052e-05, -3.0546727e-02],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "#定义环境\n",
    "class MyWrapper(gym.Wrapper):\n",
    "\n",
    "    def __init__(self):\n",
    "        env = gym.make('CartPole-v1')\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "\n",
    "    def reset(self):\n",
    "        state, _ = self.env.reset()\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        state, reward, done, _, info = self.env.step(action)\n",
    "        return state, reward, done, info\n",
    "\n",
    "\n",
    "MyWrapper().reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "消耗时间= 6.74487829208374\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/anaconda3/envs/pt39/lib/python3.9/site-packages/stable_baselines3/common/evaluation.py:67: UserWarning: Evaluation environment is not wrapped with a ``Monitor`` wrapper. This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. Consider wrapping environment first with ``Monitor`` wrapper.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(436.45, 356.1218155350778)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv\n",
    "from stable_baselines3 import PPO\n",
    "from stable_baselines3.common.evaluation import evaluate_policy\n",
    "import time\n",
    "\n",
    "\n",
    "def test_multiple_env(dumm, N):\n",
    "    if dumm:\n",
    "        #DummyVecEnv,在单线程中运行多个环境\n",
    "        env = DummyVecEnv([MyWrapper] * N)\n",
    "    else:\n",
    "        #SubprocVecEnv,在多线程中运行多个环境\n",
    "        env = SubprocVecEnv([MyWrapper] * N, start_method='fork')\n",
    "\n",
    "    start = time.time()\n",
    "\n",
    "    #训练一个模型\n",
    "    model = PPO('MlpPolicy', env, verbose=0).learn(total_timesteps=5000)\n",
    "\n",
    "    print('消耗时间=', time.time() - start)\n",
    "\n",
    "    #关闭环境\n",
    "    env.close()\n",
    "\n",
    "    #测试\n",
    "    return evaluate_policy(model, MyWrapper(), n_eval_episodes=20)\n",
    "\n",
    "\n",
    "test_multiple_env(dumm=True, N=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "消耗时间= 12.886700868606567\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(337.0, 119.68667427913601)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_multiple_env(dumm=True, N=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "消耗时间= 9.097013473510742\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(371.3, 279.9510850130787)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_multiple_env(dumm=False, N=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "消耗时间= 14.181455612182617\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(481.85, 260.54352323556236)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_multiple_env(dumm=False, N=10)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "include_colab_link": true,
   "name": "3_multiprocessing.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
